﻿using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using HAPI;
using HuginApiAddonsCS.Extensions;

namespace HuginApiAddonsCS.DomainExpander
{
    public class Expander
    {
        /// <summary>
        /// Loads a domain from a net file, appends <paramref name="numExpansions"/> time steps
        /// via <see cref="ExpandDomain(Domain, DomainType, int)"/>, saves the result as a net file
        /// and finally releases the loaded domain.
        /// </summary>
        /// <param name="inputNet">File path to the input net file</param>
        /// <param name="outputNet">File path to the output net file</param>
        /// <param name="dType">The type of domain for expansion</param>
        /// <param name="numExpansions">Number of extra time steps to add to the domain</param>
        public void ExpandDomain(string inputNet, string outputNet, DomainType dType, int numExpansions)
        {
            ParseListener pl = new DefaultClassParseListener();
            Domain inDomain = new Domain(inputNet, pl);

            try
            {
                // ExpandDomain mutates and returns the same domain object.
                inDomain = ExpandDomain(inDomain, dType, numExpansions);
                inDomain.SaveAsNet(outputNet);
            }
            finally
            {
                // HAPI domains are released explicitly; do it even if expansion/saving fails.
                inDomain.Delete();
            }
        }

        /// <summary>
        /// Expands the domain in place by appending new time steps after the current last one.
        /// </summary>
        /// <param name="inputDomain">Input domain (modified in place)</param>
        /// <param name="dType">Domain type</param>
        /// <param name="numExpansions">Number of extra time steps to add; values &lt;= 0 add nothing</param>
        /// <returns>The expanded domain (the same object as <paramref name="inputDomain"/>)</returns>
        public Domain ExpandDomain(Domain inputDomain, DomainType dType, int numExpansions)
        {
            List<string> prefixes = inputDomain.GetPrefixes();
            int tLast = inputDomain.GetLastTimeStep();
            // The horizontal spacing is measured once from the original layout and reused for every new step.
            int tGap = GetGapBetweenSteps(inputDomain, prefixes, tLast);

            for (int i = 0; i < numExpansions; i++)
            {
                inputDomain = ExpandStep(inputDomain, prefixes, tLast, tGap, dType);
                tLast++;
            }

            return inputDomain;
        }

        /// <summary>
        /// Estimates the horizontal (X) gap between consecutive time steps by averaging, over all
        /// prefixes, the X distance between node <c>prefix + lastStep</c> and <c>prefix + (lastStep - 1)</c>.
        /// When a prefix has no node at <c>lastStep - 1</c>, the gap measured for the preceding prefix
        /// (or 0 if none was measured yet) is counted in its place.
        /// </summary>
        /// <param name="inputDomain">input domain</param>
        /// <param name="prefixes">Node name prefixes</param>
        /// <param name="lastStep">last time step</param>
        /// <returns>The average gap between nodes, or 0 when there are no prefixes</returns>
        private int GetGapBetweenSteps(Domain inputDomain, List<string> prefixes, int lastStep)
        {
            // Nothing to measure (and nothing to expand); avoid dividing by zero below.
            if (prefixes.Count == 0)
            {
                return 0;
            }

            int gap = 0;
            int lastGap = 0;

            foreach (string prefix in prefixes)
            {
                Node endNodeA = inputDomain.GetNodeByName(prefix + lastStep);
                Node endNodeB = inputDomain.GetNodeByName(prefix + (lastStep - 1));

                if (endNodeB != null)
                {
                    lastGap = endNodeA.GetPosition().X - endNodeB.GetPosition().X;
                }

                gap += lastGap;
            }

            return gap / prefixes.Count;
        }

        /// <summary>
        /// Adds one node per prefix for time step <c>lastStep + 1</c>. Non-utility nodes are copied as
        /// labelled discrete nodes (chance → <see cref="LabelledDCNode"/>, otherwise
        /// <see cref="LabelledDDNode"/>) with the same number of states and state labels as the
        /// node at <c>lastStep</c>; utility nodes become new <see cref="UtilityNode"/>s. Each new node
        /// is placed <paramref name="gap"/> units to the right of its predecessor.
        /// </summary>
        /// <param name="inputDomain">input domain</param>
        /// <param name="prefixes">node prefixes</param>
        /// <param name="lastStep">current final time step</param>
        /// <param name="gap">x gap between points</param>
        /// <returns>domain with added nodes</returns>
        private Domain AddNewNodes(Domain inputDomain, List<string> prefixes, int lastStep, int gap)
        {
            foreach (string prefix in prefixes)
            {
                Node lastNode = inputDomain.GetNodeByName(prefix + lastStep);
                Node newNode;

                if (lastNode.GetCategory() != Node.Category.H_CATEGORY_UTILITY)
                {
                    // NOTE(review): assumes every non-utility node is discrete — a continuous node would fail this cast.
                    DiscreteNode lastNodeDiscrete = (DiscreteNode)lastNode;
                    DiscreteNode newNodeDiscrete;

                    if (lastNode.GetCategory() == Node.Category.H_CATEGORY_CHANCE)
                    {
                        newNodeDiscrete = new LabelledDCNode(inputDomain);
                    }
                    else
                    {
                        newNodeDiscrete = new LabelledDDNode(inputDomain);
                    }

                    newNodeDiscrete.SetNumberOfStates(lastNodeDiscrete.GetNumberOfStates());

                    for (uint i = 0; i < newNodeDiscrete.GetNumberOfStates(); i++)
                    {
                        newNodeDiscrete.SetStateLabel(i, lastNodeDiscrete.GetStateLabel(i));
                    }

                    newNode = newNodeDiscrete;
                }
                else
                {
                    newNode = new UtilityNode(inputDomain);
                }

                newNode.SetName(prefix + (lastStep + 1));

                // Place the new node one step-gap to the right of its predecessor.
                Point pos = lastNode.GetPosition();
                pos.X = pos.X + gap;
                newNode.SetPosition(pos);
            }

            return inputDomain;
        }

        /// <summary>
        /// Links each node at <c>lastStep + 1</c> by mirroring the parents of its counterpart at
        /// <c>lastStep</c>, shifted one time step forward. In DIDs, a decision node expanded out of
        /// time step 1 additionally keeps the original (unshifted) parents (no-forget architecture).
        /// </summary>
        /// <param name="inputDomain">input domain</param>
        /// <param name="prefixes">node prefixes</param>
        /// <param name="lastStep">current last time step</param>
        /// <param name="dType">domain type</param>
        /// <returns>updated domain model</returns>
        private Domain SetupLinksBetweenNodes(Domain inputDomain, List<string> prefixes, int lastStep, DomainType dType)
        {
            bool isDid = dType == DomainType.DID;

            foreach (string prefix in prefixes)
            {
                Node nodeA = inputDomain.GetNodeByName(prefix + lastStep);
                Node nodeB = inputDomain.GetNodeByName(prefix + (lastStep + 1));

                // DIDs have the no-forget architecture for decision nodes.
                bool keepOriginalParents = isDid
                    && nodeB.GetCategory() == Node.Category.H_CATEGORY_DECISION
                    && nodeA.GetCategory() == Node.Category.H_CATEGORY_DECISION
                    && nodeA.GetNodeTimeStep() == 1;

                // Parents are visited in reverse order; this ordering is kept from the original implementation.
                foreach (Node nodeP in Enumerable.Reverse(nodeA.GetParents()))
                {
                    // For each parent, link the same-prefix parent from the next time step.
                    nodeB.AddParent(inputDomain.GetNodeByName(nodeP.GetNodePrefix() + (nodeP.GetNodeTimeStep() + 1)));

                    if (keepOriginalParents)
                    {
                        nodeB.AddParent(nodeP);
                    }
                }
            }

            return inputDomain;
        }

        /// <summary>
        /// Copies table data from each node at <c>lastStep</c> to its counterpart at <c>lastStep + 1</c>.
        /// In DIDs, decision nodes are skipped because no-forget linking can give them a different
        /// parent set than their predecessor.
        /// </summary>
        /// <param name="inputDomain">input domain</param>
        /// <param name="prefixes">node prefixes</param>
        /// <param name="lastStep">current last time step</param>
        /// <param name="dType">domain type</param>
        /// <returns>updated domain model</returns>
        private Domain SetupTables(Domain inputDomain, List<string> prefixes, int lastStep, DomainType dType)
        {
            bool isDid = dType == DomainType.DID;

            foreach (string prefix in prefixes)
            {
                Node nodeA = inputDomain.GetNodeByName(prefix + lastStep);
                Node nodeB = inputDomain.GetNodeByName(prefix + (lastStep + 1));

                bool isDidDecision = isDid
                    && (nodeB.GetCategory() == Node.Category.H_CATEGORY_DECISION
                        || nodeA.GetCategory() == Node.Category.H_CATEGORY_DECISION);

                if (!isDidDecision)
                {
                    nodeB.GetTable().SetData(nodeA.GetTable().GetData());
                }
            }

            return inputDomain;
        }

        /// <summary>
        /// Expands the domain by one time step in 3 stages:
        /// 1. Add new nodes using prefixes and time step counter
        /// 2. Create links between nodes based on links in previous time step
        /// 3. Copy tables between steps
        /// </summary>
        /// <param name="inputDomain">input domain to be expanded</param>
        /// <param name="prefixes">prefixes of all nodes in domain</param>
        /// <param name="lastStep">current final time step</param>
        /// <param name="gap">x gap between nodes</param>
        /// <param name="dType">domain type, controls DID-specific linking and table copying</param>
        /// <returns>domain with added nodes, links and tables</returns>
        private Domain ExpandStep(Domain inputDomain, List<string> prefixes, int lastStep, int gap, DomainType dType)
        {
            inputDomain = AddNewNodes(inputDomain, prefixes, lastStep, gap);
            inputDomain = SetupLinksBetweenNodes(inputDomain, prefixes, lastStep, dType);
            inputDomain = SetupTables(inputDomain, prefixes, lastStep, dType);
            return inputDomain;
        }
    }
}
